{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "ddfa9ae6-69fe-444a-b994-8c4c5970a7ec",
   "metadata": {},
   "source": [
    "# Project - Airline AI Assistant\n",
    "\n",
    "We'll now bring together what we've learned to make an AI Customer Support assistant for an Airline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8b50bbe2-c0b1-49c3-9a5c-1ba7efa2bcb4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# imports\n",
    "\n",
    "import os\n",
    "import json\n",
    "from dotenv import load_dotenv\n",
    "from openai import OpenAI\n",
    "import gradio as gr\n",
    "import sqlite3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "747e8786-9da8-4342-b6c9-f5f69c2e22ae",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Initialization\n",
    "\n",
    "load_dotenv(override=True)\n",
    "\n",
    "openai_api_key = os.getenv('OPENAI_API_KEY')\n",
    "if openai_api_key:\n",
    "    print(f\"OpenAI API Key exists and begins {openai_api_key[:8]}\")\n",
    "else:\n",
    "    print(\"OpenAI API Key not set\")\n",
    "    \n",
    "MODEL = \"gpt-4.1-mini\"\n",
    "openai = OpenAI()\n",
    "\n",
    "DB = \"prices.db\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0a521d84-d07c-49ab-a0df-d6451499ed97",
   "metadata": {},
   "outputs": [],
   "source": [
    "system_message = \"\"\"\n",
    "You are a helpful assistant for an Airline called FlightAI.\n",
    "Give short, courteous answers, no more than 1 sentence.\n",
    "Always be accurate. If you don't know the answer, say so.\n",
    "\"\"\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c3e8173c",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_ticket_price(city):\n",
    "    print(f\"DATABASE TOOL CALLED: Getting price for {city}\", flush=True)\n",
    "    with sqlite3.connect(DB) as conn:\n",
    "        cursor = conn.cursor()\n",
    "        cursor.execute('SELECT price FROM prices WHERE city = ?', (city.lower(),))\n",
    "        result = cursor.fetchone()\n",
    "        return f\"Ticket price to {city} is ${result[0]}\" if result else \"No price data available for this city\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "03f19289",
   "metadata": {},
   "outputs": [],
   "source": [
    "get_ticket_price(\"Paris\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bcfb6523",
   "metadata": {},
   "outputs": [],
   "source": [
    "price_function = {\n",
    "    \"name\": \"get_ticket_price\",\n",
    "    \"description\": \"Get the price of a return ticket to the destination city.\",\n",
    "    \"parameters\": {\n",
    "        \"type\": \"object\",\n",
    "        \"properties\": {\n",
    "            \"destination_city\": {\n",
    "                \"type\": \"string\",\n",
    "                \"description\": \"The city that the customer wants to travel to\",\n",
    "            },\n",
    "        },\n",
    "        \"required\": [\"destination_city\"],\n",
    "        \"additionalProperties\": False\n",
    "    }\n",
    "}\n",
    "tools = [{\"type\": \"function\", \"function\": price_function}]\n",
    "tools"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "61a2a15d-b559-4844-b377-6bd5cb4949f6",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "def chat(message, history):\n",
    "    history = [{\"role\": h[\"role\"], \"content\": h[\"content\"]} for h in history]\n",
    "    messages = [{\"role\": \"system\", \"content\": system_message}] + history + [{\"role\": \"user\", \"content\": message}]\n",
    "    response = openai.chat.completions.create(model=MODEL, messages=messages)\n",
    "    return response.choices[0].message.content\n",
    "\n",
    "gr.ChatInterface(fn=chat, type=\"messages\").launch()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c91d012e",
   "metadata": {},
   "outputs": [],
   "source": [
    "def chat(message, history):\n",
    "    history = [{\"role\":h[\"role\"], \"content\":h[\"content\"]} for h in history]\n",
    "    messages = [{\"role\": \"system\", \"content\": system_message}] + history + [{\"role\": \"user\", \"content\": message}]\n",
    "    response = openai.chat.completions.create(model=MODEL, messages=messages, tools=tools)\n",
    "\n",
    "    while response.choices[0].finish_reason==\"tool_calls\":\n",
    "        message = response.choices[0].message\n",
    "        responses = handle_tool_calls(message)\n",
    "        messages.append(message)\n",
    "        messages.extend(responses)\n",
    "        response = openai.chat.completions.create(model=MODEL, messages=messages, tools=tools)\n",
    "    \n",
    "    return response.choices[0].message.content"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "956c3b61",
   "metadata": {},
   "outputs": [],
   "source": [
    "def handle_tool_calls(message):\n",
    "    responses = []\n",
    "    for tool_call in message.tool_calls:\n",
    "        if tool_call.function.name == \"get_ticket_price\":\n",
    "            arguments = json.loads(tool_call.function.arguments)\n",
    "            city = arguments.get('destination_city')\n",
    "            price_details = get_ticket_price(city)\n",
    "            responses.append({\n",
    "                \"role\": \"tool\",\n",
    "                \"content\": price_details,\n",
    "                \"tool_call_id\": tool_call.id\n",
    "            })\n",
    "    return responses"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8eca803e",
   "metadata": {},
   "outputs": [],
   "source": [
    "gr.ChatInterface(fn=chat, type=\"messages\").launch()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b369bf10",
   "metadata": {},
   "source": [
    "## A bit more about what Gradio actually does:\n",
    "\n",
    "1. Gradio constructs a frontend Svelte app based on our Python description of the UI\n",
    "2. Gradio starts a server built upon the Starlette web framework listening on a free port that serves this React app\n",
    "3. Gradio creates backend routes for our callbacks, like chat(), which calls our functions\n",
    "\n",
    "And of course when Gradio generates the frontend app, it ensures that the the Submit button calls the right backend route.\n",
    "\n",
    "That's it!\n",
    "\n",
    "It's simple, and it has a result that feels magical."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "863aac34",
   "metadata": {},
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "473e5b39-da8f-4db1-83ae-dbaca2e9531e",
   "metadata": {},
   "source": [
    "# Let's go multi-modal!!\n",
    "\n",
    "We can use DALL-E-3, the image generation model behind GPT-4o, to make us some images\n",
    "\n",
    "Let's put this in a function called artist.\n",
    "\n",
    "### Price alert: each time I generate an image it costs about 4 cents - don't go crazy with images!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2c27c4ba-8ed5-492f-add1-02ce9c81d34c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Some imports for handling images\n",
    "\n",
    "import base64\n",
    "from io import BytesIO\n",
    "from PIL import Image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "773a9f11-557e-43c9-ad50-56cbec3a0f8f",
   "metadata": {},
   "outputs": [],
   "source": [
    "def artist(city):\n",
    "    image_response = openai.images.generate(\n",
    "            model=\"dall-e-3\",\n",
    "            prompt=f\"An image representing a vacation in {city}, showing tourist spots and everything unique about {city}, in a vibrant pop-art style\",\n",
    "            size=\"1024x1024\",\n",
    "            n=1,\n",
    "            response_format=\"b64_json\",\n",
    "        )\n",
    "    image_base64 = image_response.data[0].b64_json\n",
    "    image_data = base64.b64decode(image_base64)\n",
    "    return Image.open(BytesIO(image_data))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d877c453-e7fb-482a-88aa-1a03f976b9e9",
   "metadata": {},
   "outputs": [],
   "source": [
    "image = artist(\"New York City\")\n",
    "display(image)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "728a12c5-adc3-415d-bb05-82beb73b079b",
   "metadata": {},
   "outputs": [],
   "source": [
    "def talker(message):\n",
    "    response = openai.audio.speech.create(\n",
    "      model=\"gpt-4o-mini-tts\",\n",
    "      voice=\"onyx\",    # Also, try replacing onyx with alloy or coral\n",
    "      input=message\n",
    "    )\n",
    "    return response.content"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3bc7580b",
   "metadata": {},
   "source": [
    "## Let's bring this home:\n",
    "\n",
    "1. A multi-modal AI assistant with image and audio generation\n",
    "2. Tool callling with database lookup\n",
    "3. A step towards an Agentic workflow\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b119ed1b",
   "metadata": {},
   "outputs": [],
   "source": [
    "def chat(history):\n",
    "    history = [{\"role\":h[\"role\"], \"content\":h[\"content\"]} for h in history]\n",
    "    messages = [{\"role\": \"system\", \"content\": system_message}] + history\n",
    "    response = openai.chat.completions.create(model=MODEL, messages=messages, tools=tools)\n",
    "    cities = []\n",
    "    image = None\n",
    "\n",
    "    while response.choices[0].finish_reason==\"tool_calls\":\n",
    "        message = response.choices[0].message\n",
    "        responses, cities = handle_tool_calls_and_return_cities(message)\n",
    "        messages.append(message)\n",
    "        messages.extend(responses)\n",
    "        response = openai.chat.completions.create(model=MODEL, messages=messages, tools=tools)\n",
    "\n",
    "    reply = response.choices[0].message.content\n",
    "    history += [{\"role\":\"assistant\", \"content\":reply}]\n",
    "\n",
    "    voice = talker(reply)\n",
    "\n",
    "    if cities:\n",
    "        image = artist(cities[0])\n",
    "    \n",
    "    return history, voice, image\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5846bc77",
   "metadata": {},
   "outputs": [],
   "source": [
    "def handle_tool_calls_and_return_cities(message):\n",
    "    responses = []\n",
    "    cities = []\n",
    "    for tool_call in message.tool_calls:\n",
    "        if tool_call.function.name == \"get_ticket_price\":\n",
    "            arguments = json.loads(tool_call.function.arguments)\n",
    "            city = arguments.get('destination_city')\n",
    "            cities.append(city)\n",
    "            price_details = get_ticket_price(city)\n",
    "            responses.append({\n",
    "                \"role\": \"tool\",\n",
    "                \"content\": price_details,\n",
    "                \"tool_call_id\": tool_call.id\n",
    "            })\n",
    "    return responses, cities"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6e520161",
   "metadata": {},
   "source": [
    "## The 3 types of Gradio UI\n",
    "\n",
    "`gr.Interface` is for standard, simple UIs\n",
    "\n",
    "`gr.ChatInterface` is for standard ChatBot UIs\n",
    "\n",
    "`gr.Blocks` is for custom UIs where you control the components and the callbacks"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9f250915",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Callbacks (along with the chat() function above)\n",
    "\n",
    "def put_message_in_chatbot(message, history):\n",
    "        return \"\", history + [{\"role\":\"user\", \"content\":message}]\n",
    "\n",
    "# UI definition\n",
    "\n",
    "with gr.Blocks() as ui:\n",
    "    with gr.Row():\n",
    "        chatbot = gr.Chatbot(height=500, type=\"messages\")\n",
    "        image_output = gr.Image(height=500, interactive=False)\n",
    "    with gr.Row():\n",
    "        audio_output = gr.Audio(autoplay=True)\n",
    "    with gr.Row():\n",
    "        message = gr.Textbox(label=\"Chat with our AI Assistant:\")\n",
    "\n",
    "# Hooking up events to callbacks\n",
    "\n",
    "    message.submit(put_message_in_chatbot, inputs=[message, chatbot], outputs=[message, chatbot]).then(\n",
    "        chat, inputs=chatbot, outputs=[chatbot, audio_output, image_output]\n",
    "    )\n",
    "\n",
    "ui.launch(inbrowser=True, auth=(\"ed\", \"bananas\"))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "226643d2-73e4-4252-935d-86b8019e278a",
   "metadata": {},
   "source": [
    "# Exercises and Business Applications\n",
    "\n",
    "Add in more tools - perhaps to simulate actually booking a flight. A student has done this and provided their example in the community contributions folder.\n",
    "\n",
    "Next: take this and apply it to your business. Make a multi-modal AI assistant with tools that could carry out an activity for your work. A customer support assistant? New employee onboarding assistant? So many possibilities! Also, see the week2 end of week Exercise in the separate Notebook."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7e795560-1867-42db-a256-a23b844e6fbe",
   "metadata": {},
   "source": [
    "<table style=\"margin: 0; text-align: left;\">\n",
    "    <tr>\n",
    "        <td style=\"width: 150px; height: 150px; vertical-align: middle;\">\n",
    "            <img src=\"../assets/thankyou.jpg\" width=\"150\" height=\"150\" style=\"display: block;\" />\n",
    "        </td>\n",
    "        <td>\n",
    "            <h2 style=\"color:#090;\">I have a special request for you</h2>\n",
    "            <span style=\"color:#090;\">\n",
    "                My editor tells me that it makes a HUGE difference when students rate this course on Udemy - it's one of the main ways that Udemy decides whether to show it to others. If you're able to take a minute to rate this, I'd be so very grateful! And regardless - always please reach out to me at ed@edwarddonner.com if I can help at any point.\n",
    "            </span>\n",
    "        </td>\n",
    "    </tr>\n",
    "</table>"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
